#define ARMA_DONT_PRINT_ERRORS

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

using namespace Rcpp;
using namespace arma;



// [[Rcpp::plugins(cpp11)]]
#include <omp.h>
// [[Rcpp::plugins(openmp)]]



// Natural log of 2*pi: the per-dimension normalising constant of the
// Gaussian log-density (used by dmvnrm_arma).
const double log2pi = std::log(M_PI * 2.0);




// Returns the first element (linear index 0, i.e. entry (0,0)) of `rhs`.
//
// Used to pull the scalar out of a one-element Armadillo result, such as the
// single log-density returned by dmvnrm_arma for a one-row input. Entry (0,0)
// is the same in rhs and rhs.t(), so no transposed copy is needed (the
// previous version allocated one just to read this element).
// An empty `rhs` fails Armadillo's bounds check (active unless ARMA_NO_DEBUG).
double am_mult2(const arma::mat& rhs)
{
  return rhs(0);
}


// Draws `n` samples from the multivariate normal N(mu, sigma).
//
// mu    : mean, a column with sigma.n_cols rows (repmat(mu, 1, n).t() must
//         line up with the n x sigma.n_cols matrix of shocks).
// sigma : covariance; factorised with arma::chol, which fails (throws) when it
//         is not positive definite.
// rn    : source of standard-normal variates, R's norm_rand by default. R's RNG
//         is not meant to be called from worker threads, so callers inside
//         OpenMP regions pass their own generator.
// Returns a sigma.n_cols x n matrix: one draw per column.
arma::mat mvrnormArma(int n, arma::mat mu, arma::mat sigma,
		std::function<double()> rn = norm_rand ) {
  // Shocks are drawn before the factorisation so the generator is consumed
  // in the same order (and amount) even when chol() later fails.
  arma::mat shocks(n, sigma.n_cols);
  shocks.imbue(rn);

  const arma::mat upper = arma::chol(sigma);
  const arma::mat centre = arma::repmat(mu, 1, n).t();

  arma::mat draws = centre + shocks * upper;
  return draws.t();
}


// Multivariate normal draws driven by R's own standard-normal generator
// (norm_rand); equivalent to calling mvrnormArma with its default `rn`.
// Uses R's RNG, so call it only from the main R thread.
arma::mat defaultRNG(int n, arma::mat mu, arma::mat sigma) {
  std::function<double()> r_normal = norm_rand;
  return mvrnormArma(n, mu, sigma, r_normal);
}


// Multivariate normal draws from a private std::mt19937 stream.
//
// The engine is seeded with s(k), where k is the calling OpenMP thread's
// number (0 outside a parallel region), so `s` needs at least one entry per
// thread. The seed is the double s(k) truncated to an unsigned integer.
// Throws std::out_of_range when `s` has no entry for the calling thread; the
// previous unchecked s[k] read past the end of `s` in that case.
arma::mat serial(int n, const arma::mat mu, arma::mat sigma, const arma::vec& s) {
  const arma::uword tid = static_cast<arma::uword>(omp_get_thread_num());
  if (tid >= s.n_elem) {
    throw std::out_of_range("serial: `s` has no seed for this OpenMP thread");
  }

  std::mt19937 engine(static_cast<uint64_t>(s[tid]));
  std::normal_distribution<double> nor(0.0, 1.0);
  return mvrnormArma(n, mu, sigma, [&]() { return nor(engine); });
}





// Multivariate normal density N(mean, sigma) evaluated at every row of `x`.
//
// x     : one observation per row (x.n_cols must equal mean.n_elem).
// mean  : mean as a row vector.
// sigma : covariance; its Cholesky factor is inverted, so it must be
//         positive definite (arma::chol throws otherwise).
// logd  : return log-densities when true, densities when false.
// Returns a column with x.n_rows entries.
arma::vec dmvnrm_arma(arma::mat x,  
                      arma::rowvec mean,  
                      arma::mat sigma, 
                      bool logd = false) { 
    const arma::uword n_obs = x.n_rows;
    const double half_dim = static_cast<double>(x.n_cols) / 2.0;

    // Transposed inverse of the upper Cholesky factor: z = inv_root * (x - mean)'
    // is the whitened residual, and the sum of log(diag) gives -0.5 * log|sigma|.
    const arma::mat inv_root =
        arma::trans(arma::inv(arma::trimatu(arma::chol(sigma))));
    const double log_det_term = arma::sum(arma::log(inv_root.diag()));
    const double norm_const = -half_dim * log2pi;

    arma::vec log_dens(n_obs);
    for (arma::uword row = 0; row < n_obs; ++row) {
        const arma::vec z = inv_root * arma::trans(x.row(row) - mean);
        log_dens(row) = norm_const - 0.5 * arma::sum(z % z) + log_det_term;
    }

    if (logd) {
        return log_dens;
    }
    return arma::vec(arma::exp(log_dens));
}







// Returns the first element (linear index 0) of `rhs` as a double.
//
// Used to turn a 1x1 Armadillo result, e.g. sum(a) - sum(b) on column
// matrices, into a plain scalar. Reads the element directly; the previous
// version deep-copied the whole matrix first.
// An empty `rhs` fails Armadillo's bounds check (active unless ARMA_NO_DEBUG).
double vm_convert(const arma::mat& rhs)
{
  return rhs(0);
}





// [[Rcpp::export()]]
Rcpp::List mcmc_logit_index (arma::mat y,
arma::mat t,
arma::mat X,
arma::mat I_country,
arma::mat I_party,
arma::mat Cov_Beta,
arma::mat beta2start,
arma::mat beta3start,
arma::vec sigmastart,
arma::mat reffect_countrystart,
arma::mat reffect_partystart,
arma::rowvec densitybeta,
arma::mat DiagWish, 
int mcmc = 100,
int burn=10,
int thin=10,
int chains=3
 ) {

 // Dimensions
int r= X.n_cols ;
int N=X.n_rows;
int n_country=I_country.n_cols;
int n_party=I_party.n_cols;


 // Current Containers

  arma::mat beta2 = beta2start ;
  arma::mat beta3 = beta3start ;
  arma::mat beta2new(r,1);
  beta2new.fill(0.0);
  arma::mat beta3new(r,1);
  beta3new.fill(0.0);
   
 arma::vec sigma=sigmastart;
 arma::vec sigmanew(2);
 sigmanew.fill(1.0) ;


arma::mat sigmabetadensity(r,r,arma::fill::eye);

 arma::mat reffect_country = reffect_countrystart ;

 arma::mat reffect_party = reffect_partystart ;
 

 arma::mat etaold_beta(N, 2) ;
 etaold_beta.fill(0.0) ;

 arma::mat lold_beta(N,1);
 lold_beta.fill(0.0) ;

 arma::mat etanew_beta(N, 3) ;
 etanew_beta.fill(0.0) ;

 arma::mat lnew_beta(N,1);
 lnew_beta.fill(0.0) ;


 arma::mat eta_sigma1(N, 2) ;
 eta_sigma1.fill(0.0) ;

 arma::mat lold_sigma1(N,1);
 lold_sigma1.fill(0.0) ;

 arma::mat lnew_sigma1(N,1);
 lnew_sigma1.fill(0.0) ;

 arma::mat eta_sigma2(N, 2) ;
 eta_sigma2.fill(0.0) ;

 arma::mat lold_sigma2(N,1);
 lold_sigma2.fill(0.0) ;

 arma::mat lnew_sigma2(N,1);
 lnew_sigma2.fill(0.0) ;

 arma::mat etaold_country(N, 2) ;
 etaold_country.fill(0.0) ;
 arma::mat etanew_country(N, 2) ;
 etanew_country.fill(0.0) ;


 arma::mat etaold_party(N, 2) ;
 etaold_party.fill(0.0) ;
 arma::mat etanew_party(N, 2) ;
 etanew_party.fill(0.0) ;


 arma::mat lold_country(N,1);
 lold_country.fill(0.0);
 arma::mat lnew_country(N,1);
 lnew_country.fill(0.0);


 arma::mat lold_party(N,1);
 lold_party.fill(0.0);
 arma::mat lnew_party(N,1);
 lnew_party.fill(0.0);


 arma::mat lcountry(n_country, 1) ;
 lcountry.fill(0.0) ;


 arma::mat lparty(n_party, 1) ;
 lparty.fill(0.0) ;


arma::mat reffect_country_new(n_country,2);
arma::mat reffect_party_new(n_party,2);


 arma::mat mu0(1, 2) ;
 mu0.fill(0.0) ;

double var=1.0;
arma::vec varscountry(2) ;
varscountry.fill(var) ;
 arma::mat sigma2country(2, 2) ;
 sigma2country.fill(0.0) ;
 sigma2country.diag() = varscountry ;

arma::vec varsparty(2) ;
varsparty.fill(var) ;
 arma::mat sigma2party(2, 2) ;
 sigma2party.fill(0.0) ;
 sigma2party.diag() = varsparty ;




 int lastit= (mcmc-burn)/thin	;

// Trace Containers
  arma::cube tracebeta2(r, lastit,chains) ;
  arma::cube tracebeta3(r, lastit,chains) ;
  arma::cube tracesigma(2, lastit, chains);
 arma::cube tracesigma2country(4, lastit,chains) ;
 arma::cube tracereffect_country(n_country*2, lastit,chains) ;
 arma::cube tracesigma2party(4, lastit,chains) ;
 arma::cube tracereffect_party(n_party*2, lastit,chains) ;

arma::vec seeds = linspace(0, 2, chains);




 
#pragma omp parallel num_threads(chains)
  {


std::mt19937 engine(
static_cast<uint64_t>(seeds[omp_get_thread_num()] ) );
    std::uniform_real_distribution<double> uni(0.0, 1.0);
     std::normal_distribution<double> nor(0, 1);

  std::random_device rd;
    std::mt19937 gen(rd());



#pragma omp for 
for (int chain = 0; chain < chains; ++chain) {

 // START SAMPLING
for (int iter = 0 ; iter < mcmc ; ++iter) {


etaold_beta.col(0)=X*beta2+I_country*reffect_country.col(0)+I_party*reffect_party.col(0);
etaold_beta.col(1)=X*beta3+I_country*reffect_country.col(1)+I_party*reffect_party.col(1);


for (int n = 0 ; n < N ; n++) {
if (y(n) == 1) {
lold_beta(n)=log(1-R::pnorm((log(t(n))-etaold_beta(n,0))/sigma(0),0.0,1,1,0))+log(1-R::pnorm((log(t(n))-etaold_beta(n,1))/sigma(1),0.0,1,1,0));
} else if (y(n) == 2) {
lold_beta(n)=log((1/sigma(0)*t(n))*R::dnorm((log(t(n))-etaold_beta(n,0))/sigma(0),0.0,1,0))+log(1-R::pnorm((log(t(n))-etaold_beta(n,1))/sigma(1),0.0,1,1,0));
} else if (y(n) == 3) {
lold_beta(n)=log(1-R::pnorm((log(t(n))-etaold_beta(n,0))/sigma(0),0.0,1,1,0))+log((1/sigma(1)*t(n))*R::dnorm((log(t(n))-etaold_beta(n,1))/sigma(1),0.0,1,0));
}
}


beta2new=mvrnormArma(1, beta2, Cov_Beta, [&](){return nor(engine);});
beta3new=mvrnormArma(1, beta3, Cov_Beta, [&](){return nor(engine);});


etanew_beta.col(0)=X*beta2new+I_country*reffect_country.col(0)+I_party*reffect_party.col(0);
etanew_beta.col(1)=X*beta3new+I_country*reffect_country.col(1)+I_party*reffect_party.col(1);



for (int n = 0 ; n < N ; n++) {
if (y(n) == 1) {
lnew_beta(n)=log(1-R::pnorm((log(t(n))-etanew_beta(n,0))/sigma(0),0.0,1,1,0))+log(1-R::pnorm((log(t(n))-etanew_beta(n,1))/sigma(1),0.0,1,1,0));
} else if (y(n) == 2) {
lnew_beta(n)=log((1/sigma(0)*t(n))*R::dnorm((log(t(n))-etanew_beta(n,0))/sigma(0),0.0,1,0))+log(1-R::pnorm((log(t(n))-etanew_beta(n,1))/sigma(1),0.0,1,1,0));
} else if (y(n) == 3) {
lnew_beta(n)=log(1-R::pnorm((log(t(n))-etanew_beta(n,0))/sigma(0),0.0,1,1,0))+log((1/sigma(1)*t(n))*R::dnorm((log(t(n))-etanew_beta(n,1))/sigma(1),0.0,1,0));
}
}


double accept_beta= vm_convert(sum(lnew_beta)-sum(lold_beta))+
(am_mult2(dmvnrm_arma(beta2new.t(),densitybeta, sigmabetadensity, true))-am_mult2(dmvnrm_arma(beta2.t(),densitybeta, sigmabetadensity, true)))+
(am_mult2(dmvnrm_arma(beta3new.t(),densitybeta, sigmabetadensity, true))-am_mult2(dmvnrm_arma(beta3.t(),densitybeta, sigmabetadensity, true)));


if(log(uni(engine))<accept_beta) {
for (int l=0; l<r; ++l) {
beta2(l)=beta2new(l);
beta3(l)=beta3new(l);
}
}



sigmanew(0)=sigma(0)+nor(engine)*0.1;


if (sigmanew(0)>0) {

eta_sigma1.col(0)=X*beta2+I_country*reffect_country.col(0)+I_party*reffect_party.col(0);
eta_sigma1.col(1)=X*beta3+I_country*reffect_country.col(1)+I_party*reffect_party.col(1);


for (int n = 0 ; n < N ; n++) {
if (y(n) == 1) {
lold_sigma1(n)=log(1-R::pnorm((log(t(n))-eta_sigma1(n,0))/sigma(0),0.0,1,1,0))+log(1-R::pnorm((log(t(n))-eta_sigma1(n,1))/sigma(1),0.0,1,1,0));
} else if (y(n) == 2) {
lold_sigma1(n)=log((1/sigma(0)*t(n))*R::dnorm((log(t(n))-eta_sigma1(n,0))/sigma(0),0.0,1,0))+log(1-R::pnorm((log(t(n))-eta_sigma1(n,1))/sigma(1),0.0,1,1,0));
} else if (y(n) == 3) {
lold_sigma1(n)=log(1-R::pnorm((log(t(n))-eta_sigma1(n,0))/sigma(0),0.0,1,1,0))+log((1/sigma(1)*t(n))*R::dnorm((log(t(n))-eta_sigma1(n,1))/sigma(1),0.0,1,0));
}
}

for (int n = 0 ; n < N ; n++) {
if (y(n) == 1) {
lnew_sigma1(n)=log(1-R::pnorm((log(t(n))-eta_sigma1(n,0))/sigmanew(0),0.0,1,1,0))+log(1-R::pnorm((log(t(n))-eta_sigma1(n,1))/sigma(1),0.0,1,1,0));
} else if (y(n) == 2) {
lnew_sigma1(n)=log((1/sigmanew(0)*t(n))*R::dnorm((log(t(n))-eta_sigma1(n,0))/sigmanew(0),0.0,1,0))+log(1-R::pnorm((log(t(n))-eta_sigma1(n,1))/sigma(1),0.0,1,1,0));
} else if (y(n) == 3) {
lnew_sigma1(n)=log(1-R::pnorm((log(t(n))-eta_sigma1(n,0))/sigmanew(0),0.0,1,1,0))+log((1/sigma(1)*t(n))*R::dnorm((log(t(n))-eta_sigma1(n,1))/sigma(1),0.0,1,0));
}
}



double accept_sigma1= vm_convert(sum(lnew_sigma1)-sum(lold_sigma1));

if(log(uni(engine))<accept_sigma1) {
sigma(0)=sigmanew(0);
}

}




sigmanew(1)=sigma(1)+nor(engine)*0.1;


if (sigmanew(1)>0) {

eta_sigma2.col(0)=X*beta2+I_country*reffect_country.col(0)+I_party*reffect_party.col(0);
eta_sigma2.col(1)=X*beta3+I_country*reffect_country.col(1)+I_party*reffect_party.col(1);


for (int n = 0 ; n < N ; n++) {
if (y(n) == 1) {
lold_sigma2(n)=log(1-R::pnorm((log(t(n))-eta_sigma2(n,0))/sigma(0),0.0,1,1,0))+log(1-R::pnorm((log(t(n))-eta_sigma2(n,1))/sigma(1),0.0,1,1,0));
} else if (y(n) == 2) {
lold_sigma2(n)=log((1/sigma(0)*t(n))*R::dnorm((log(t(n))-eta_sigma2(n,0))/sigma(0),0.0,1,0))+log(1-R::pnorm((log(t(n))-eta_sigma2(n,1))/sigma(1),0.0,1,1,0));
} else if (y(n) == 3) {
lold_sigma2(n)=log(1-R::pnorm((log(t(n))-eta_sigma2(n,0))/sigma(0),0.0,1,1,0))+log((1/sigma(1)*t(n))*R::dnorm((log(t(n))-eta_sigma2(n,1))/sigma(1),0.0,1,0));
}
}

for (int n = 0 ; n < N ; n++) {
if (y(n) == 1) {
lnew_sigma2(n)=log(1-R::pnorm((log(t(n))-eta_sigma2(n,0))/sigma(0),0.0,1,1,0))+log(1-R::pnorm((log(t(n))-eta_sigma2(n,1))/sigmanew(1),0.0,1,1,0));
} else if (y(n) == 2) {
lnew_sigma2(n)=log((1/sigma(0)*t(n))*R::dnorm((log(t(n))-eta_sigma2(n,0))/sigma(0),0.0,1,0))+log(1-R::pnorm((log(t(n))-eta_sigma2(n,1))/sigmanew(1),0.0,1,1,0));
} else if (y(n) == 3) {
lnew_sigma2(n)=log(1-R::pnorm((log(t(n))-eta_sigma2(n,0))/sigma(0),0.0,1,1,0))+log((1/sigmanew(1)*t(n))*R::dnorm((log(t(n))-eta_sigma2(n,1))/sigmanew(1),0.0,1,0));
}
}




double accept_sigma2= vm_convert(sum(lnew_sigma2)-sum(lold_sigma2));

if(log(uni(engine))<accept_sigma2) {
sigma(1)=sigmanew(1);
}

}

 // TRACE
if ((iter+1)> burn & (iter+1)%thin==0) {
int j=((iter)-burn)/thin;
tracebeta2.subcube(0, j, chain, (r-1), j, chain)=beta2;
tracebeta3.subcube(0, j, chain, (r-1), j, chain)=beta3;
tracesigma.subcube(0, j, chain, 1, j, chain)=sigma;
tracereffect_country.subcube(0, j, chain, 2*n_country-1, j, chain)=vectorise(reffect_country);
tracesigma2country.subcube(0, j, chain, 3, j, chain)=vectorise(sigma2country);
tracereffect_party.subcube(0, j, chain, 2*n_party-1, j, chain)=vectorise(reffect_party);
tracesigma2party.subcube(0, j, chain, 3, j, chain)=vectorise(sigma2party);

} //end trace
}

}  // end iterations
} // 

// Returns
Rcpp::List ret;
 ret["beta2"] = tracebeta2 ;
 ret["beta3"] = tracebeta3 ;
ret["sigma"]=tracesigma;
ret["reffect_country"] = tracereffect_country;
ret["sigma2country"]=tracesigma2country;
ret["reffect_party"] = tracereffect_party;
ret["sigma2party"]=tracesigma2party;

 return(ret) ;
 

}


